import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pymc as pm

Reconstruction Error
It’s natural to seek shortcuts.
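One way to make that trade-off precise is reconstruction error: compress the data, decode it back, and measure what was lost. For a data matrix $X$ and its reconstruction $\hat{X}$, a natural score (my notation, not the post's) is the squared Frobenius norm

$$\lVert X - \hat{X} \rVert_F^2 = \sum_{i,j} \left(x_{ij} - \hat{x}_{ij}\right)^2 ,$$

and by the Eckart–Young theorem the rank-$k$ truncated SVD minimises it among all rank-$k$ reconstructions. The VAE and Bayesian models below trade that linear optimality for more flexible, probabilistic notions of reconstruction.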
Job Satisfaction Data
# Standard deviations
stds = np.array([0.939, 1.017, 0.937, 0.562, 0.760, 0.524,
0.585, 0.609, 0.731, 0.711, 1.124, 1.001])
n = len(stds)
# Lower triangular correlation values as a flat list
corr_values = [
1.000,
.668, 1.000,
.635, .599, 1.000,
.263, .261, .164, 1.000,
.290, .315, .247, .486, 1.000,
.207, .245, .231, .251, .449, 1.000,
-.206, -.182, -.195, -.309, -.266, -.142, 1.000,
-.280, -.241, -.238, -.344, -.305, -.230, .753, 1.000,
-.258, -.244, -.185, -.255, -.255, -.215, .554, .587, 1.000,
.080, .096, .094, -.017, .151, .141, -.074, -.111, .016, 1.000,
.061, .028, -.035, -.058, -.051, -.003, -.040, -.040, -.018, .284, 1.000,
.113, .174, .059, .063, .138, .044, -.119, -.073, -.084, .563, .379, 1.000
]
# Fill correlation matrix
corr_matrix = np.zeros((n, n))
idx = 0
for i in range(n):
for j in range(i+1):
corr_matrix[i, j] = corr_values[idx]
corr_matrix[j, i] = corr_values[idx]
idx += 1
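Equivalently, and purely as an aside not in the original, NumPy's triangular-index helper builds the same matrix without the explicit loop, since np.tril_indices walks the lower triangle in the same row-major order as corr_values:

# Hypothetical vectorized alternative to the fill loop above
corr_alt = np.zeros((n, n))
corr_alt[np.tril_indices(n)] = corr_values
corr_alt = corr_alt + corr_alt.T - np.diag(np.diag(corr_alt))  # symmetrize
assert np.allclose(corr_alt, corr_matrix)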
# Covariance matrix: Sigma = D * R * D
cov_matrix = np.outer(stds, stds) * corr_matrix
#cov_matrix_test = np.dot(np.dot(np.diag(stds), corr_matrix), np.diag(stds))
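The commented-out line above suggests an equivalent construction; a quick sanity check (my addition, not from the original post) confirms the two agree and that the covariance is positive definite, which np.random.multivariate_normal relies on below:

# Sanity checks on the assembled covariance
cov_matrix_test = np.diag(stds) @ corr_matrix @ np.diag(stds)
assert np.allclose(cov_matrix, cov_matrix_test)
np.linalg.cholesky(cov_matrix)  # raises LinAlgError if not positive definite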
columns = ["JW1","JW2","JW3", "UF1","UF2","FOR", "DA1","DA2","DA3", "EBA","ST","MI"]
corr_df = pd.DataFrame(corr_matrix, columns=columns, index=columns)
cov_df = pd.DataFrame(cov_matrix, columns=columns, index=columns)
cov_df
def make_sample(cov_matrix, size, columns):
    # Draw `size` observations from a zero-mean multivariate normal
    sample_df = pd.DataFrame(np.random.multivariate_normal(np.zeros(len(columns)), cov_matrix, size=size), columns=columns)
    return sample_df
sample_df = make_sample(cov_matrix, 263, columns)
sample_df.head()

|   | JW1 | JW2 | JW3 | UF1 | UF2 | FOR | DA1 | DA2 | DA3 | EBA | ST | MI |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.169852 | -1.192894 | -0.795520 | 0.377262 | 0.210613 | 0.007787 | 0.102942 | -0.566327 | -0.814464 | 1.118203 | 2.504131 | 0.533175 |
| 1 | -0.026490 | -0.481563 | -0.543942 | -0.430214 | -0.916923 | -0.292995 | -0.393448 | -1.053847 | -1.197843 | 1.073452 | 0.358204 | 0.053217 |
| 2 | -0.256757 | -0.995389 | -0.564365 | 1.008702 | -0.394069 | 0.413988 | 0.029353 | -0.368301 | -0.068482 | -0.067150 | 0.143072 | 0.151521 |
| 3 | -0.224012 | 0.869592 | 0.784490 | -0.395174 | -0.103010 | 0.795385 | -0.100986 | 0.145035 | 0.089296 | 0.331997 | 0.052695 | -0.536667 |
| 4 | -0.715423 | -1.714706 | -0.872284 | -0.783565 | -1.012569 | -0.265503 | 0.247473 | 0.283395 | 0.059887 | -0.553747 | 0.395177 | -2.008416 |
data = sample_df.corr()
def plot_heatmap(data, title="Correlation Matrix", vmin=-.2, vmax=.2, ax=None, figsize=(10, 6), colorbar=True):
    data_matrix = data.values
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(data_matrix, cmap='viridis', vmin=vmin, vmax=vmax)
    # Annotate each cell with its value, choosing a contrasting text colour
    for i in range(data_matrix.shape[0]):
        for j in range(data_matrix.shape[1]):
            ax.text(
                j, i,                        # x, y coordinates
                f"{data_matrix[i, j]:.2f}",  # text to display
                ha="center", va="center",    # center alignment
                color="white" if data_matrix[i, j] < 0.5 else "black"
            )
    ax.set_title(title)
    # Set tick positions *before* the labels to avoid matplotlib's
    # fixed-locator warning
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_xticklabels(data.columns)
    ax.set_yticks(np.arange(data.shape[0]))
    ax.set_yticklabels(data.index)
    if colorbar:
        plt.colorbar(im)
plot_heatmap(data, vmin=-1, vmax=1)
X = make_sample(cov_matrix, 100, columns=columns)
U, S, VT = np.linalg.svd(X, full_matrices=False)

ranks = [2, 5, 12]
reconstructions = []
for k in ranks:
X_k = U[:, :k] @ np.diag(S[:k]) @ VT[:k, :]
reconstructions.append(X_k)
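Before plotting, it is worth quantifying how much structure each truncation keeps. A small check (my addition) reports the relative Frobenius reconstruction error per rank; the rank-12 version should reproduce X essentially exactly:

# Relative reconstruction error for each truncation rank
for k, X_k in zip(ranks, reconstructions):
    rel_err = np.linalg.norm(X - X_k) / np.linalg.norm(X)
    print(f"rank {k}: relative Frobenius error {rel_err:.3f}")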
# Plot original and reconstructed matrices
fig, axes = plt.subplots(1, len(ranks) + 1, figsize=(10,15))
axes[0].imshow(X, cmap='viridis')
axes[0].set_title("Original")
axes[0].axis("off")
for ax, k, X_k in zip(axes[1:], ranks, reconstructions):
ax.imshow(X_k, cmap='viridis')
ax.set_title(f"Rank {k}")
ax.axis("off")
plt.suptitle("Reconstruction of Data Using SVD \n various truncation options",fontsize=12, x=.5, y=1.01)
plt.tight_layout()
plt.show()

Variational Auto-Encoders
class NumericVAE(nn.Module):
def __init__(self, n_features, hidden_dim=64, latent_dim=8):
super().__init__()
# ---------- ENCODER ----------
# First layer: compress input features into a hidden representation
self.fc1 = nn.Linear(n_features, hidden_dim)
# Latent space parameters (q(z|x)): mean and log-variance
self.fc_mu = nn.Linear(hidden_dim, latent_dim) # μ(x)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim) # log(σ^2(x))
# ---------- DECODER ----------
# First layer: map latent variable z back into hidden representation
self.fc2 = nn.Linear(latent_dim, hidden_dim)
# Output distribution parameters for reconstruction p(x|z)
# For numeric data, we predict both mean and log-variance per feature
self.fc_out_mu = nn.Linear(hidden_dim, n_features) # μ_x(z)
self.fc_out_logvar = nn.Linear(hidden_dim, n_features) # log(σ^2_x(z))
# ENCODER forward pass: input x -> latent mean, log-variance
def encode(self, x):
h = F.relu(self.fc1(x)) # Hidden layer with ReLU
mu = self.fc_mu(h) # Latent mean vector
logvar = self.fc_logvar(h) # Latent log-variance vector
return mu, logvar
# Reparameterization trick: sample z = μ + σ * ε (ε ~ N(0,1))
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar) # σ = exp(0.5 * logvar)
eps = torch.randn_like(std) # ε ~ N(0, I)
return mu + eps * std # z = μ + σ * ε
# DECODER forward pass: latent z -> reconstructed mean, log-variance
def decode(self, z):
h = F.relu(self.fc2(z)) # Hidden layer with ReLU
recon_mu = self.fc_out_mu(h) # Mean of reconstructed features
recon_logvar = self.fc_out_logvar(h)# Log-variance of reconstructed features
return recon_mu, recon_logvar
# Full forward pass: input x -> reconstructed (mean, logvar), latent params
def forward(self, x):
mu, logvar = self.encode(x) # q(z|x)
z = self.reparameterize(mu, logvar) # Sample z from q(z|x)
recon_mu, recon_logvar = self.decode(z)# p(x|z)
return (recon_mu, recon_logvar), mu, logvar
# Sample new synthetic data: z ~ N(0,I), decode to x
def generate(self, n_samples=100):
self.eval()
with torch.no_grad():
# Sample z from standard normal prior
z = torch.randn(n_samples, self.fc_mu.out_features)
# Decode to get reconstruction distribution parameters
cont_mu, cont_logvar = self.decode(z)
# Sample from reconstructed Gaussian: μ_x + σ_x * ε
            return cont_mu + torch.exp(0.5 * cont_logvar) * torch.randn_like(cont_mu)

def vae_loss(recon_mu, recon_logvar, x, mu, logvar):
# Reconstruction loss: Gaussian log likelihood
recon_var = torch.exp(recon_logvar)
recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) + (x - recon_mu) ** 2 / recon_var)
recon_loss = recon_nll.sum(dim=1).mean() # sum over features, mean over batch
# KL divergence: D_KL(q(z|x) || p(z)) where p(z)=N(0,I)
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
kl_loss = kl_div.mean()
    return recon_loss + kl_loss, recon_loss, kl_loss

def prep_data_vae(sample_size=1000):
sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)
X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
train_loader = torch.utils.data.DataLoader(X_train, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(X_test, batch_size=32)
return train_loader, test_loader
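Before committing to a long training run, a quick smoke test (my addition, using only the pieces defined above) confirms the shapes line up and the loss is finite on a single batch:

# Hypothetical smoke test for NumericVAE and vae_loss
vae_check = NumericVAE(n_features=len(columns))
batch = next(iter(prep_data_vae(200)[0]))
(recon_mu, recon_logvar), mu, logvar = vae_check(batch)
loss, _, _ = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)
assert recon_mu.shape == batch.shape and torch.isfinite(loss)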
def train_vae(vae, optimizer, train, test, patience=10, wait=0, n_epochs=1000):
best_loss = float('inf')
losses = []
for epoch in range(n_epochs):
vae.train()
train_loss = 0.0
for batch in train:
optimizer.zero_grad()
(recon_mu, recon_logvar), mu, logvar = vae(batch)
loss, recon_loss, kl_loss = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)
loss.backward()
optimizer.step()
train_loss += loss.item() * batch.size(0)
avg_train_loss = train_loss / train.dataset.shape[0]
# --- Test Loop ---
vae.eval()
test_loss = 0.0
with torch.no_grad():
for batch in test:
(recon_mu, recon_logvar), mu, logvar = vae(batch)
loss, _, _ = vae_loss(recon_mu, recon_logvar, batch, mu, logvar)
test_loss += loss.item() * batch.size(0)
avg_test_loss = test_loss / test.dataset.shape[0]
print(f"Epoch {epoch+1}/{n_epochs} | Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")
if avg_test_loss < best_loss - 1e-4:
best_loss, wait = avg_test_loss, 0
best_state = vae.state_dict()
else:
wait += 1
if wait >= patience:
print(f"Early stopping at epoch {epoch+1}")
vae.load_state_dict(best_state) # restore best
break
losses.append([avg_train_loss, avg_test_loss, best_loss])
return vae, pd.DataFrame(losses, columns=['train_loss', 'test_loss', 'best_loss'])
train_500, test_500 = prep_data_vae(500)
vae = NumericVAE(n_features=train_500.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_500, losses_df_500 = train_vae(vae, optimizer, train_500, test_500)
train_1000, test_1000 = prep_data_vae(1000)
vae = NumericVAE(n_features=train_1000.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_1000, losses_df_1000 = train_vae(vae, optimizer, train_1000, test_1000)
train_10_000, test_10_000 = prep_data_vae(10_000)
vae = NumericVAE(n_features=train_10_000.dataset.shape[1], hidden_dim=64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
vae_fit_10_000, losses_df_10_000 = train_vae(vae, optimizer, train_10_000, test_10_000)

fig, axs = plt.subplots(1, 3, figsize=(8, 6))
axs=axs.flatten()
losses_df_500[['train_loss', 'test_loss']].plot(ax=axs[0])
losses_df_1000[['train_loss', 'test_loss']].plot(ax=axs[1])
losses_df_10_000[['train_loss', 'test_loss']].plot(ax=axs[2])
axs[0].set_title("Training and Test Losses \n 500 observations");
axs[1].set_title("Training and Test Losses \n 1000 observations");
axs[2].set_title("Training and Test Losses \n 10,000 observations");

def bootstrap_residuals(vae_fit, X_test, sample_df, n_boot=1000):
recons = []
resid_array = np.zeros((n_boot, len(sample_df.columns), len(sample_df.columns)))
for i in range(n_boot):
recon_data = vae_fit.generate(n_samples=len(X_test))
reconstructed_df = pd.DataFrame(recon_data, columns=sample_df.columns)
resid = pd.DataFrame(X_test, columns=sample_df.columns).corr() - reconstructed_df.corr()
resid_array[i] = resid.values
recons.append(reconstructed_df)
avg_resid = resid_array.mean(axis=0)
bootstrapped_resids = pd.DataFrame(avg_resid, columns=sample_df.columns, index=sample_df.columns)
return bootstrapped_resids
bootstrapped_resids_500 = bootstrap_residuals(vae_fit_500, pd.DataFrame(test_500.dataset, columns=sample_df.columns), sample_df)
bootstrapped_resids_1000 = bootstrap_residuals(vae_fit_1000, pd.DataFrame(test_1000.dataset, columns=sample_df.columns), sample_df)
bootstrapped_resids_10_000 = bootstrap_residuals(vae_fit_10_000, pd.DataFrame(test_10_000.dataset, columns=sample_df.columns), sample_df)
fig, axs = plt.subplots(3, 1, figsize=(10, 20))
axs = axs.flatten()
plot_heatmap(bootstrapped_resids_500, title="""Expected Correlation Residuals for 500 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[0], colorbar=True, vmin=-.25, vmax=.25)
plot_heatmap(bootstrapped_resids_1000, title="""Expected Correlation Residuals for 1000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[1], colorbar=True, vmin=-.25, vmax=.25)
plot_heatmap(bootstrapped_resids_10_000, title="""Expected Correlation Residuals for 10,000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[2], colorbar=True, vmin=-.25, vmax=.25)
Missing Data
sample_df_missing = sample_df.copy()
# Randomly pick 5% of the total elements
mask_remove = np.random.rand(*sample_df_missing.shape) < 0.05
# Set those elements to NaN
sample_df_missing[mask_remove] = np.nan
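A one-line check (my addition) confirms the realised missingness is close to the 5% target:

print(f"missing fraction: {sample_df_missing.isna().mean().mean():.3f}")  # ~0.05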
sample_df_missing.head()

|   | JW1 | JW2 | JW3 | UF1 | UF2 | FOR | DA1 | DA2 | DA3 | EBA | ST | MI |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.169852 | -1.192894 | -0.795520 | NaN | 0.210613 | 0.007787 | 0.102942 | -0.566327 | -0.814464 | 1.118203 | 2.504131 | 0.533175 |
| 1 | -0.026490 | -0.481563 | -0.543942 | -0.430214 | -0.916923 | -0.292995 | -0.393448 | -1.053847 | -1.197843 | 1.073452 | 0.358204 | 0.053217 |
| 2 | -0.256757 | -0.995389 | -0.564365 | 1.008702 | -0.394069 | 0.413988 | 0.029353 | -0.368301 | -0.068482 | -0.067150 | NaN | 0.151521 |
| 3 | -0.224012 | 0.869592 | 0.784490 | -0.395174 | -0.103010 | 0.795385 | -0.100986 | 0.145035 | 0.089296 | 0.331997 | 0.052695 | -0.536667 |
| 4 | -0.715423 | -1.714706 | -0.872284 | -0.783565 | -1.012569 | -0.265503 | 0.247473 | 0.283395 | 0.059887 | -0.553747 | 0.395177 | -2.008416 |
class MissingDataDataset(Dataset):
def __init__(self, x, mask):
# x and mask are tensors of same shape
self.x = x
self.mask = mask
def __len__(self):
return self.x.shape[0]
def __getitem__(self, idx):
return self.x[idx], self.mask[idx]
def prep_data_vae_missing(sample_size=1000, batch_size=32, missing_frac=0.2):
sample_df = make_sample(cov_matrix=cov_matrix, size=sample_size, columns=columns)
total_values = sample_df.size
num_nans = int(total_values * missing_frac)
# Choose random flat indices
nan_indices = np.random.choice(total_values, num_nans, replace=False)
    # Convert flat indices to (row, col)
    rows, cols = np.unravel_index(nan_indices, sample_df.shape)
    # Set those values to NaN: mutate the underlying array and rebuild the
    # frame, since writing through .values is unreliable under pandas copy-on-write
    values = sample_df.to_numpy()
    values[rows, cols] = np.nan
    sample_df = pd.DataFrame(values, columns=columns)
X_train, X_test = train_test_split(sample_df.values, test_size=0.2, random_state=890)
# Mask: 1=observed, 0=missing
mask_train = ~pd.DataFrame(X_train).isna()
mask_test = ~pd.DataFrame(X_test).isna()
# Tensors (keep NaNs for missing values)
x_train_tensor = torch.tensor(X_train, dtype=torch.float32)
mask_train_tensor = torch.tensor(mask_train.values, dtype=torch.float32)
x_test_tensor = torch.tensor(X_test, dtype=torch.float32)
mask_test_tensor = torch.tensor(mask_test.values, dtype=torch.float32)
train_dataset = MissingDataDataset(x_train_tensor, mask_train_tensor)
test_dataset = MissingDataDataset(x_test_tensor, mask_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
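As a sketch of a consistency check (mine, not the original's), the mask should be 0 exactly where the features are NaN in every batch the loaders yield:

train_check, _ = prep_data_vae_missing(200)
for x_b, m_b in train_check:
    assert torch.equal(torch.isnan(x_b), m_b == 0)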
def vae_loss_missing(recon_mu, recon_logvar, x_filled, mu, logvar, mask):
"""
VAE loss that skips missing values (NaNs) in x for the reconstruction term.
"""
# Reconstruction loss (Gaussian NLL) only on observed values
recon_var = torch.exp(recon_logvar)
recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) + (x_filled - recon_mu) ** 2 / recon_var)
# Apply mask and normalize by number of observed features per sample
recon_nll = recon_nll * mask # zero-out missing features
obs_counts = mask.sum(dim=1).clamp(min=1) # avoid division by 0
recon_loss = (recon_nll.sum(dim=1) / obs_counts).mean()
# KL divergence as usual
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
kl_loss = kl_div.mean()
    return recon_loss, kl_loss
class NumericVAE_missing(nn.Module):
def __init__(self, n_features, hidden_dim=64, latent_dim=8):
super().__init__()
self.n_features = n_features
# ---------- Learnable Imputation ----------
# One learnable parameter per feature for missing values
self.missing_embeddings = nn.Parameter(torch.zeros(n_features))
# ---------- ENCODER ----------
self.fc1_x = nn.Linear(n_features, hidden_dim)
# Stronger mask encoder: 2-layer MLP
self.fc1_mask = nn.Sequential(
nn.Linear(n_features, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# Combine feature and mask embeddings
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# ---------- DECODER ----------
self.fc2 = nn.Linear(latent_dim, hidden_dim)
self.fc_out_mu = nn.Linear(hidden_dim, n_features)
self.fc_out_logvar = nn.Linear(hidden_dim, n_features)
def encode(self, x, mask):
# Impute missing values with learnable parameters
x_filled = torch.where(
torch.isnan(x),
self.missing_embeddings.expand_as(x),
x
)
# Encode features and mask separately
h_x = F.relu(self.fc1_x(x_filled))
h_mask = self.fc1_mask(mask)
# Combine embeddings
h = h_x + h_mask
# Latent space
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h = F.relu(self.fc2(z))
recon_mu = self.fc_out_mu(h)
recon_logvar = self.fc_out_logvar(h)
return recon_mu, recon_logvar
def forward(self, x, mask):
mu, logvar = self.encode(x, mask)
z = self.reparameterize(mu, logvar)
recon_mu, recon_logvar = self.decode(z)
return (recon_mu, recon_logvar), mu, logvar
def generate(self, n_samples=100):
self.eval()
with torch.no_grad():
z = torch.randn(n_samples, self.fc_mu.out_features)
recon_mu, recon_logvar = self.decode(z)
            return recon_mu + torch.exp(0.5 * recon_logvar) * torch.randn_like(recon_mu)

# Revised loss: unlike the earlier helper, this version zero-fills the
# missing entries itself before computing the masked likelihood
def vae_loss_missing(recon_mu, recon_logvar, x, mu, logvar, mask):
# Fill missing values with 0 just for loss computation
x_filled = torch.where(mask.bool(), x, torch.zeros_like(x))
recon_var = torch.exp(recon_logvar)
recon_nll = 0.5 * (torch.log(2 * torch.pi * recon_var) +
(x_filled - recon_mu) ** 2 / recon_var)
# Mask out missing values
recon_nll = recon_nll * mask
obs_counts = mask.sum(dim=1).clamp(min=1)
recon_loss = (recon_nll.sum(dim=1) / obs_counts).mean()
# KL divergence
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
kl_loss = kl_div.mean()
return recon_loss, kl_loss
def beta_annealing(epoch, max_beta=1.0, anneal_epochs=100):
    # Linear KL warm-up: ramp beta from 0 to max_beta over `anneal_epochs`,
    # which keeps the KL term from collapsing the posterior early in training
    beta = min(max_beta, max_beta * epoch / anneal_epochs)
    return beta

def fit_vae_missing(vae_missing, train_loader, test_loader, optimizer, patience=10, wait=0, n_epochs=1000):
best_loss = float('inf')
losses = []
for epoch in range(n_epochs):
beta = beta_annealing(epoch, max_beta=1.0, anneal_epochs=10)
vae_missing.train()
train_loss = 0
for x_batch, mask_batch in train_loader:
optimizer.zero_grad()
(recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
loss = recon_loss + beta * kl_loss
loss.backward()
optimizer.step()
train_loss += loss.item() * x_batch.size(0)
avg_train_loss = train_loss / len(train_loader.dataset)
# --- Validation ---
vae_missing.eval()
test_loss = 0.0
with torch.no_grad():
for x_batch, mask_batch in test_loader:
(recon_mu, recon_logvar), mu, logvar = vae_missing(x_batch, mask_batch)
recon_loss, kl_loss = vae_loss_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
loss = recon_loss + kl_loss
test_loss += loss.item() * x_batch.size(0)
avg_test_loss = test_loss / len(test_loader.dataset)
print(f"Epoch {epoch+1}/{n_epochs} | Train: {avg_train_loss:.4f} | Test: {avg_test_loss:.4f}")
# Early stopping
if avg_test_loss < best_loss - 1e-4:
best_loss, wait = avg_test_loss, 0
best_state = vae_missing.state_dict()
else:
wait += 1
if wait >= patience:
print(f"Early stopping at epoch {epoch+1}")
vae_missing.load_state_dict(best_state) # restore best
break
losses.append([avg_train_loss, avg_test_loss, best_loss])
losses_df = pd.DataFrame(losses, columns=['train_loss', 'test_loss', 'best_loss'])
return vae_missing, losses_df
train_loader_5000, test_loader_5000 = prep_data_vae_missing(5000, batch_size=32)
vae_missing_5000 = NumericVAE_missing(n_features=next(iter(train_loader_5000))[0].shape[1], latent_dim=12)
optimizer = optim.Adam(vae_missing_5000.parameters(), lr=1e-4)
fit_vae_missing_5000, losses_df_missing_5000 = fit_vae_missing(vae_missing_5000, train_loader_5000, test_loader_5000, optimizer)
train_loader_25_000, test_loader_25_000 = prep_data_vae_missing(25_000, batch_size=32)
vae_missing_25_000 = NumericVAE_missing(n_features=next(iter(train_loader_25_000))[0].shape[1], latent_dim=12)
optimizer = optim.Adam(vae_missing_25_000.parameters(), lr=1e-4)
fit_vae_missing_25_000, losses_df_missing_25_000 = fit_vae_missing(vae_missing_25_000, train_loader_25_000, test_loader_25_000, optimizer)

bootstrapped_resids_5000 = bootstrap_residuals(fit_vae_missing_5000, pd.DataFrame(test_loader_5000.dataset.x, columns=sample_df.columns), sample_df)
bootstrapped_resids_25_000 = bootstrap_residuals(fit_vae_missing_25_000, pd.DataFrame(test_loader_25_000.dataset.x, columns=sample_df.columns), sample_df)
fig, axs = plt.subplots(2, 1, figsize=(10, 20))
axs = axs.flatten()
plot_heatmap(bootstrapped_resids_5000, title="""Expected Correlation Residuals for 5000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[0], colorbar=True, vmin=-.25, vmax=.25)
plot_heatmap(bootstrapped_resids_25_000, title="""Expected Correlation Residuals for 25,000 observations \n Under 1000 Bootstrapped Reconstructions""", ax=axs[1], colorbar=True, vmin=-.25, vmax=.25)
Bayesian Inference
def make_pymc_model(sample_df):
    coords = {'features': sample_df.columns,
              'features1': sample_df.columns,
              'obs': range(len(sample_df))}
    with pm.Model(coords=coords) as model:
        # Priors
        mus = pm.Normal("mus", 0, 1, dims='features')
        chol, _, _ = pm.LKJCholeskyCov("chol", n=sample_df.shape[1], eta=1.0, sd_dist=pm.HalfNormal.dist(1))
        cov = pm.Deterministic('cov', pm.math.dot(chol, chol.T), dims=('features', 'features1'))
        pm.MvNormal('likelihood', mus, cov=cov, observed=sample_df.values, dims=('obs', 'features'))
        idata = pm.sample_prior_predictive()
        idata.extend(pm.sample(random_seed=120))
        pm.sample_posterior_predictive(idata, extend_inferencedata=True)
    return idata, model
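For parity with the VAE's generate method, we can also draw synthetic rows straight from the posterior predictive. A hedged sketch (the helper name and indexing are mine, assuming the dims declared in the model above):

def generate_bayes(idata, n_samples=100):
    # Flatten (chain, draw) into one sample dimension, then take the first
    # observation slot of each posterior predictive draw
    ppc = idata.posterior_predictive['likelihood'].stack(sample=("chain", "draw"))
    draws = ppc.isel(obs=0).transpose("sample", "features").values
    return pd.DataFrame(draws[:n_samples], columns=columns)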
idata, model = make_pymc_model(sample_df)

pm.model_to_graphviz(model)

import arviz as az
expected_corr = pd.DataFrame(az.summary(idata, var_names=['chol_corr'])['mean'].values.reshape((12, 12)), columns=sample_df.columns, index=sample_df.columns)
resids = sample_df.corr() - expected_corr
plot_heatmap(resids)
Missing Data
idata_missing, model_missing = make_pymc_model(sample_df_missing)

pm.model_to_graphviz(model_missing)

expected_corr = pd.DataFrame(az.summary(idata_missing, var_names=['chol_corr'])['mean'].values.reshape((12, 12)), columns=sample_df.columns, index=sample_df.columns)
resids = sample_df.corr() - expected_corr
plot_heatmap(resids)
Citation
BibTeX citation:
@online{forde2025,
author = {Forde, Nathaniel},
title = {Amortized {Bayesian} {Inference} with {PyTorch}},
date = {2025-07-25},
langid = {en},
abstract = {The cost of generating new sample data can be prohibitive.
There is a secondary but different cost which attaches to the
“construction” of novel data. Principal Components Analysis can be
seen as a technique to optimally reconstruct a complex multivariate
data set from a lower level compressed dimensional space.
Variational auto-encoders allow us to achieve yet more flexible
reconstruction results in non-linear cases. Drawing a new sample
from the posterior predictive distribution of Bayesian models
  similarly supplies us with insight into the variability of realised
data. Both methods assume a latent model of the data generating
process that aims to leverage a compressed representation of the
data. These are different heuristics with different consequences for
how we understand the variability in the world. Both encode
  different degrees of inductive bias through architectural
specification but of the two, Bayesian methods do so transparently.}
}
For attribution, please cite this work as:
Forde, Nathaniel. 2025. “Amortized Bayesian Inference with
PyTorch.” July 25, 2025.